{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DDPG.ipynb",
      "version": "0.3.2",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "YJ6-cbagZ5CT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.autograd as autograd\n",
        "import torch.optim as optim\n",
        "\n",
        "import gym\n",
        "import random\n",
        "import numpy as np\n",
        "from collections import deque"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DCYDiWHiZ8Gp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def mini_batch_train(env, agent, max_episodes, max_steps, batch_size):\n",
        "    episode_rewards = []\n",
        "\n",
        "    for episode in range(max_episodes):\n",
        "        state = env.reset()\n",
        "        episode_reward = 0\n",
        "        \n",
        "        for step in range(max_steps):\n",
        "            action = agent.get_action(state)\n",
        "            next_state, reward, done, _ = env.step(action)\n",
        "            agent.replay_buffer.push(state, action, reward, next_state, done)\n",
        "            episode_reward += reward\n",
        "\n",
        "            if len(agent.replay_buffer) > batch_size:\n",
        "                agent.update(batch_size)   \n",
        "\n",
        "            if done or step == max_steps-1:\n",
        "                episode_rewards.append(episode_reward)\n",
        "                print(\"Episode \" + str(episode) + \": \" + str(episode_reward))\n",
        "                break\n",
        "\n",
        "            state = next_state\n",
        "\n",
        "    return episode_rewards"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Rz0bMnzdaJyN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BasicBuffer:\n",
        "\n",
        "  def __init__(self, max_size):\n",
        "      self.max_size = max_size\n",
        "      self.buffer = deque(maxlen=max_size)\n",
        "\n",
        "  def push(self, state, action, reward, next_state, done):\n",
        "      experience = (state, action, np.array([reward]), next_state, done)\n",
        "      self.buffer.append(experience)\n",
        "\n",
        "  def sample(self, batch_size):\n",
        "      state_batch = []\n",
        "      action_batch = []\n",
        "      reward_batch = []\n",
        "      next_state_batch = []\n",
        "      done_batch = []\n",
        "\n",
        "      batch = random.sample(self.buffer, batch_size)\n",
        "\n",
        "      for experience in batch:\n",
        "          state, action, reward, next_state, done = experience\n",
        "          state_batch.append(state)\n",
        "          action_batch.append(action)\n",
        "          reward_batch.append(reward)\n",
        "          next_state_batch.append(next_state)\n",
        "          done_batch.append(done)\n",
        "\n",
        "      return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)\n",
        "\n",
        "  def __len__(self):\n",
        "      return len(self.buffer)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "USTdUdkGaWmg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Ornstein-Ulhenbeck Noise\n",
        "# Taken from #https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py\n",
        "class OUNoise(object):\n",
        "    def __init__(self, action_space, mu=0.0, theta=0.15, max_sigma=0.3, min_sigma=0.3, decay_period=100000):\n",
        "        self.mu           = mu\n",
        "        self.theta        = theta\n",
        "        self.sigma        = max_sigma\n",
        "        self.max_sigma    = max_sigma\n",
        "        self.min_sigma    = min_sigma\n",
        "        self.decay_period = decay_period\n",
        "        self.action_dim   = action_space.shape[0]\n",
        "        self.low          = action_space.low\n",
        "        self.high         = action_space.high\n",
        "        self.reset()\n",
        "        \n",
        "    def reset(self):\n",
        "        self.state = np.ones(self.action_dim) * self.mu\n",
        "        \n",
        "    def evolve_state(self):\n",
        "        x  = self.state\n",
        "        dx = self.theta * (self.mu - x) + self.sigma * np.random.randn(self.action_dim)\n",
        "        self.state = x + dx\n",
        "        return self.state\n",
        "    \n",
        "    def get_action(self, action, t=0):\n",
        "        ou_state = self.evolve_state()\n",
        "        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, t / self.decay_period)\n",
        "        return np.clip(action + ou_state, self.low, self.high)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wYw7DI5-aAiw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Critic(nn.Module):\n",
        "\n",
        "    def __init__(self, obs_dim, action_dim):\n",
        "        super(Critic, self).__init__()\n",
        "\n",
        "        self.obs_dim = obs_dim\n",
        "        self.action_dim = action_dim\n",
        "\n",
        "        self.linear1 = nn.Linear(self.obs_dim, 1024)\n",
        "        self.linear2 = nn.Linear(1024 + self.action_dim, 512)\n",
        "        self.linear3 = nn.Linear(512, 300)\n",
        "        self.linear4 = nn.Linear(300, 1)\n",
        "\n",
        "    def forward(self, x, a):\n",
        "        x = F.relu(self.linear1(x))\n",
        "        xa_cat = torch.cat([x,a], 1)\n",
        "        xa = F.relu(self.linear2(xa_cat))\n",
        "        xa = F.relu(self.linear3(xa))\n",
        "        qval = self.linear4(xa)\n",
        "\n",
        "        return qval\n",
        "\n",
        "class Actor(nn.Module):\n",
        "\n",
        "    def __init__(self, obs_dim, action_dim):\n",
        "        super(Actor, self).__init__()\n",
        "\n",
        "        self.obs_dim = obs_dim\n",
        "        self.action_dim = action_dim\n",
        "\n",
        "        self.linear1 = nn.Linear(self.obs_dim, 512)\n",
        "        self.linear2 = nn.Linear(512, 128)\n",
        "        self.linear3 = nn.Linear(128, self.action_dim)\n",
        "\n",
        "    def forward(self, obs):\n",
        "        x = F.relu(self.linear1(obs))\n",
        "        x = F.relu(self.linear2(x))\n",
        "        x = torch.tanh(self.linear3(x))\n",
        "\n",
        "        return x"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gADalF9XaAkv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class DDPGAgent:\n",
        "    \n",
        "    def __init__(self, env, gamma, tau, buffer_maxlen, critic_learning_rate, actor_learning_rate):\n",
        "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "        \n",
        "        self.env = env\n",
        "        self.obs_dim = env.observation_space.shape[0]\n",
        "        self.action_dim = env.action_space.shape[0]\n",
        "        \n",
        "        # hyperparameters\n",
        "        self.env = env\n",
        "        self.gamma = gamma\n",
        "        self.tau = tau\n",
        "        \n",
        "        # initialize actor and critic networks\n",
        "        self.critic = Critic(self.obs_dim, self.action_dim).to(self.device)\n",
        "        self.critic_target = Critic(self.obs_dim, self.action_dim).to(self.device)\n",
        "        \n",
        "        self.actor = Actor(self.obs_dim, self.action_dim).to(self.device)\n",
        "        self.actor_target = Actor(self.obs_dim, self.action_dim).to(self.device)\n",
        "    \n",
        "        # Copy critic target parameters\n",
        "        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):\n",
        "            target_param.data.copy_(param.data)\n",
        "        \n",
        "        # optimizers\n",
        "        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate)\n",
        "        self.actor_optimizer  = optim.Adam(self.actor.parameters(), lr=actor_learning_rate)\n",
        "    \n",
        "        self.replay_buffer = BasicBuffer(buffer_maxlen)        \n",
        "        self.noise = OUNoise(self.env.action_space)\n",
        "        \n",
        "    def get_action(self, obs):\n",
        "        state = torch.FloatTensor(obs).unsqueeze(0).to(self.device)\n",
        "        action = self.actor.forward(state)\n",
        "        action = action.squeeze(0).cpu().detach().numpy()\n",
        "\n",
        "        return action\n",
        "    \n",
        "    def update(self, batch_size):\n",
        "        states, actions, rewards, next_states, _ = self.replay_buffer.sample(batch_size)\n",
        "        state_batch, action_batch, reward_batch, next_state_batch, masks = self.replay_buffer.sample(batch_size)\n",
        "        state_batch = torch.FloatTensor(state_batch).to(self.device)\n",
        "        action_batch = torch.FloatTensor(action_batch).to(self.device)\n",
        "        reward_batch = torch.FloatTensor(reward_batch).to(self.device)\n",
        "        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)\n",
        "        masks = torch.FloatTensor(masks).to(self.device)\n",
        "   \n",
        "        curr_Q = self.critic.forward(state_batch, action_batch)\n",
        "        next_actions = self.actor_target.forward(next_state_batch)\n",
        "        next_Q = self.critic_target.forward(next_state_batch, next_actions.detach())\n",
        "        expected_Q = reward_batch + self.gamma * next_Q\n",
        "        \n",
        "        # update critic\n",
        "        q_loss = F.mse_loss(curr_Q, expected_Q.detach())\n",
        "\n",
        "        self.critic_optimizer.zero_grad()\n",
        "        q_loss.backward() \n",
        "        self.critic_optimizer.step()\n",
        "\n",
        "        # update actor\n",
        "        policy_loss = -self.critic.forward(state_batch, self.actor.forward(state_batch)).mean()\n",
        "        \n",
        "        self.actor_optimizer.zero_grad()\n",
        "        policy_loss.backward()\n",
        "        self.actor_optimizer.step()\n",
        "\n",
        "        # update target networks \n",
        "        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):\n",
        "            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))\n",
        "       \n",
        "        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):\n",
        "            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ifOTkUA4aRkt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "env = gym.make(\"Pendulum-v0\")\n",
        "\n",
        "max_episodes = 100\n",
        "max_steps = 500\n",
        "batch_size = 32\n",
        "\n",
        "gamma = 0.99\n",
        "tau = 1e-2\n",
        "buffer_maxlen = 100000\n",
        "critic_lr = 1e-3\n",
        "actor_lr = 1e-3\n",
        "\n",
        "agent = DDPGAgent(env, gamma, tau, buffer_maxlen, critic_lr, actor_lr)\n",
        "episode_rewards = mini_batch_train(env, agent, max_episodes, max_steps, batch_size)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}